package graph;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

/**
 * 获取一个连通图上的点到一个特定的点的最小距离，
 * 限制：联通图上，不能存在有环的距离和为负数的情况，否则会一直在环处旋转，死循环
 *
 * @author Liaorun
 */
public class Dijkstra {


    public static HashMap<Node, Integer> dijkstra(Node head) {
        // head 出发到所有点的最小距离
        // key: 从head出发到达key
        // value: 从head出发到达key的最小距离
        // 如果在表中，没有T的记录，含义是从head出发到T这个点的距离为正无穷
        HashMap<Node, Integer> distanceMap = new HashMap<>(1 << 4);

        distanceMap.put(head, 0);

        // 已经求过距离的节点,存在selectedNodes中，以后在也不碰
        HashSet<Node> selectedNodes = new HashSet<>();

        // 查找一个到头节点最近的节点， 这里肯定获得是头节点，因为，selectedNodes为空，distanceMap里面只添加了头节点
        Node minNode = getMinDistanceAndUnselectedNode(distanceMap, selectedNodes);

        // 跳出条件：连通图的所有的点都处理了，都记录在selectedNodes 中了
        while (minNode != null) {
            int distanceToHead = distanceMap.get(minNode);

            // 遍历和当前点连接的点
            for (Edge edge : minNode.edges) {

                // 和当前点连接的点
                Node toNode = edge.to;


                if (!distanceMap.containsKey(toNode)) {
                    // 头节点到这个点的距离为 +finity,可以更新距离
                    distanceMap.put(toNode, distanceToHead + edge.weight);
                } else {
                    // if((当前点到头节点的距离 + 当前点到邻接点的距离) < 原来的路径到下一个节点的距离）
                    // {找到一个更好的路径，更新距离}
                    distanceMap.put(toNode, Math.min(distanceMap.get(toNode), distanceToHead + edge.weight));
                }
            }

            // 记录当前点已经来过了
            selectedNodes.add(minNode);
            // 查找下一个到头节点最近的节点
            minNode = getMinDistanceAndUnselectedNode(distanceMap, selectedNodes);
        }

        // 返回每个节点到头节点的最短距离
        return distanceMap;
    }

    /**
     * 在distanceMap 中选一个到头节点最近的点，且这个点没有被选择过
     *
     * @param distanceMap   点到头节点的距离map
     * @param selectedNodes 已经选择过的点
     * @return 一个到头节点最近的点，且这个点没有被选择过 or null
     */
    private static Node getMinDistanceAndUnselectedNode(HashMap<Node, Integer> distanceMap, HashSet<Node> selectedNodes) {
        Node minNode = null;
        int minDistance = Integer.MAX_VALUE;
        for (Map.Entry<Node, Integer> entry : distanceMap.entrySet()) {
            Node node = entry.getKey();
            Integer distance = entry.getValue();
            // 首先判断是没有选择过的点，如果是 判断是否距离是当前最小距离
            if (!selectedNodes.contains(node) && distance < minDistance) {
                // 满足条件，记录这个点
                minNode = node;
                // 刷新最小距离
                minDistance = distance;
            }
        }

        return minNode;
    }

    public static class NodeRecord {
        public Node node;
        public int distance;

        public NodeRecord(Node node, int distance) {
            this.node = node;
            this.distance = distance;
        }
    }

    public static class NodeHeap {
        private Node[] nodes;
        private HashMap<Node, Integer> heapIndexMap;
        private HashMap<Node, Integer> distanceMap;
        private int size;

        public NodeHeap(int size) {
            nodes = new Node[size];
            heapIndexMap = new HashMap<>();
            distanceMap = new HashMap<>();
            this.size = 0;
        }

        public void addOrUpdateOrIgnore(Node node, int distance) {
            if (inHeap(node)) {
                // 是添加过的节点
                // 更新距离
                distanceMap.put(node, Math.min(distanceMap.get(node), distance));
                // 向上浮重新堆化
                insertHeapify(heapIndexMap.get(node));
            }

            if (!isEntered(node)) {
                // 是没添加过的节点

                // 添加节点在末尾
                nodes[size] = node;
                // 记录节点位置
                heapIndexMap.put(node, size);
                // 记录开始节点到对应节点的最短距离
                distanceMap.put(node, distance);
                // 向上浮重新堆化,然后size++
                insertHeapify(size++);
            }
        }

        /**
         * 该节点是否进来过堆
         *
         * @param node 节点
         * @return true 进入过
         */
        private boolean isEntered(Node node) {
            return heapIndexMap.containsKey(node);
        }

        /**
         * 向上堆化
         *
         * @param index 开始节点
         */
        private void insertHeapify(Integer index) {
            while (distanceMap.get(nodes[index]) < distanceMap.get(nodes[(index - 1) / 2])) {
                swap(index, (index - 1) / 2);
                index = (index - 1) / 2;
            }
        }

        /**
         * 交换两个节点的位置
         *
         * @param a a位置
         * @param b b位置
         */
        private void swap(Integer a, int b) {
            heapIndexMap.put(nodes[a], b);
            heapIndexMap.put(nodes[b], a);
            Node tmp = nodes[a];
            nodes[a] = nodes[b];
            nodes[b] = tmp;
        }

        /**
         * 判断节点是否在堆上
         *
         * @param node 节点
         * @return 节点是否在堆上
         */
        private boolean inHeap(Node node) {
            // 在记录位置的map中存在且值不为 -1
            return isEntered(node) && heapIndexMap.get(node) != -1;
        }

        public boolean isEmpty() {
            return size == 0;
        }

        /**
         * 弹出堆顶元素并堆化
         *
         * @return 堆顶元素
         */
        public NodeRecord pop() {

            NodeRecord record = new NodeRecord(nodes[0], distanceMap.get(nodes[0]));

            // 交换栈顶和末尾的节点
            swap(0, size - 1);

            // 弹出的节点位置记录不删除，而是变成-1，表示弹出了
            heapIndexMap.put(nodes[size - 1], -1);
            // 删除该点的距离信息
            distanceMap.remove(nodes[size - 1]);
            // 删除指向原栈顶元素的指针
            nodes[size - 1] = null;
            // 向下堆化
            heapify(0, --size);

            return record;
        }

        /**
         * 向下堆化
         *
         * @param index 开始节点
         * @param size  堆的大小
         */
        private void heapify(int index, int size) {
            // 计算左子节点的位置
            int left = index * 2 + 1;
            while (left < size) {
                int smallest = left + 1 < size && distanceMap.get(nodes[left + 1]) < distanceMap.get(nodes[left]) ? left + 1 : left;
                smallest = distanceMap.get(nodes[smallest]) < distanceMap.get(index) ? smallest : index;
                // 子节点比父节点大结束
                if (smallest == index) {
                    break;
                }
                // 子节点比父节点小，交换，下一轮
                swap(smallest, index);
                // 小的子节点变成下一个处理的节点
                index = smallest;
                // 更新左子节点的位置
                left = index * 2 + 1;
            }
        }
    }


    /**
     * 获取一个连通图上的所有点到一个特定的点的最小距离，
     *
     * @param head 特定的点
     * @param size 图上点的数目
     * @return 每个可到达的点的最短距离
     */
    public static HashMap<Node, Integer> dijstra2(Node head, int size) {
        NodeHeap nodeHeap = new NodeHeap(size);

        // 添加起点
        nodeHeap.addOrUpdateOrIgnore(head, 0);
        // 记录结果的无序map
        HashMap<Node, Integer> result = new HashMap<>(size);


        if (!nodeHeap.isEmpty()) {
            // 弹出小根堆的栈顶节点
            NodeRecord record = nodeHeap.pop();
            Node cur = record.node;
            int distance = record.distance;
            // 和当前节点相连的节点的距离都更新或添加
            for (Edge edge : cur.edges) {
                nodeHeap.addOrUpdateOrIgnore(edge.to, edge.weight + distance);
            }


            // 这个算法可以保证每次弹出的节点的距离就是到起点的最短距离
            // 找不到更短的了
            result.put(cur, distance);
        }

        return result;
    }
}
